9. Iterative CS-based MR image reconstruction from Cartesian data¶

In this tutorial we will reconstruct an MR image from Cartesian under-sampled k-space measurements.

We use the toy datasets available in pysap, more specifically a 2D brain slice and a Cartesian acquisition scheme. We compare the zero-order image reconstruction with compressed sensing (CS) reconstructions (analysis vs synthesis formulation), using the FISTA algorithm for the synthesis formulation and the Condat-Vu algorithm for the analysis formulation. Sparsity will be promoted in the wavelet domain, using either Symmlet-8 (analysis and synthesis) or undecimated bi-orthogonal wavelets (analysis only).

Recall that the synthesis formulation reads (minimization in the sparsifying domain):

$$ \widehat{z} = \text{arg}\,\min_{z\in \mathbb{C}^{n_\Psi}} \frac{1}{2} \|y - \Omega F \Psi^* z \|_2^2 + \lambda \|z\|_1 $$

and the image solution is given by $\widehat{x} = \Psi^*\widehat{z}$. For an orthonormal wavelet transform, we have $n_\Psi=n$ while for a frame we may have $n_\Psi > n$.

while the analysis formulation consists in minimizing the following cost function (minimization in the image domain):

$$ \widehat{x} = \text{arg}\,\min_{x\in \mathbb{C}^n} \frac{1}{2} \|y - \Omega F x\|_2^2 + \lambda \|\Psi x\|_1 \,. $$
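In both formulations the $\ell_1$ term is handled through its proximal operator, soft thresholding, which is what the `SparseThreshold(..., thresh_type="soft")` operator used later in this notebook applies. A minimal numpy sketch:

```python
import numpy as np

def soft_threshold(z, t):
    """Proximal operator of t * ||.||_1: shrink every coefficient toward zero by t,
    zeroing out anything smaller than t in magnitude."""
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)
```

This one-liner is the workhorse of every iteration below: FISTA and POGM apply it to the synthesis coefficients, while Condat-Vu uses the proximal operator of its convex conjugate (a clipping).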

  • Author: Chaithya G R & Philippe Ciuciu
  • Date: 01/06/2021
  • Update: 04/02/2025
  • Target: ATSI MSc students, Paris-Saclay University
pip install git+https://github.com/CEA-COSMIC/pysap-mri.git

pip install git+https://github.com/CEA-COSMIC/pysap.git

pip install git+https://github.com/paquiteau/brainweb-dl.git
In [240]:
#DISPLAY BRAIN PHANTOM
'''
In medical imaging (e.g., MRI, CT) and signal processing, a phantom is artificial
data or a physical model used for testing, calibration, or research purposes.
It is typically a synthetic (simulated) image or an experimental setup that mimics
real-world structures such as the brain, the heart, or other human tissue.
'''
# 1. Import the required libraries

%matplotlib inline

import numpy as np
from mri.operators.utils import convert_mask_to_locations       # convert a sampling mask to k-space sample locations
from mri.reconstructors import SingleChannelReconstructor
from mri.operators import FFT, WaveletN, WaveletUD2             # Fourier and wavelet operators
import os.path as op
import os
import math ; import cmath
import matplotlib.pyplot as plt
import sys
from modopt.math.metrics import ssim                            # structural similarity (SSIM) metric
from modopt.opt.proximity import SparseThreshold                # soft-thresholding proximity operator
from modopt.opt.linear import Identity                          # identity linear operator

from skimage import data, io, filters

import pywt as pw

import brainweb_dl as bwdl                                      # download MRI phantom images from the BrainWeb database

# 2. Set Matplotlib display parameters
plt.rcParams["image.origin"]="lower"                            # image origin at the bottom-left corner (default: top-left)
plt.rcParams["image.cmap"]='Greys_r'                            # reversed grey-scale (black-and-white) colormap

# 3. Download and process the MRI image
mri_img = bwdl.get_mri(4, "T1")[70, ...].astype(np.float32)
'''
bwdl.get_mri(4, "T1") : download a T1-weighted MRI scan from the BrainWeb database
[70, ...] : take slice 70 of the 3D MRI volume, i.e. a 2D slice image
.astype(np.float32) : convert the data to float32 for computation
'''
#mri_img = bwdl.get_mri(4, "T2")[120, ...].astype(np.float32)

# 4. Display the MRI image shape
print(mri_img.shape)                                            # (256, 256)
img_size = mri_img.shape[0]                                     # image width (assuming a square image)

# 5. Plot the MRI image
plt.figure()
plt.imshow(abs(mri_img))                                        # abs() guards against negative pixel values affecting the display
plt.title("Original brain image")
plt.show()
(256, 256)
In [241]:
#image = get_sample_data('2d-mri')
# Obtain K-Space Cartesian Mask
#mask = get_sample_data("cartesian-mri-mask")

# 1. Get the MRI image
image = mri_img 

# 2. Import the random phase-encoding generator
from mrinufft.trajectories.tools import get_random_loc_1d

# 3. Generate the phase-encoding sampling locations
'''
image.shape[0] : width/height of the image (assuming a square MRI image)
accel=8 : acceleration factor of 8, i.e. 8x fewer sampled k-space lines (faster acquisition)
center_prop=0.1 : the central 10% of k-space is always sampled (the low-frequency region matters most)
pdf='gaussian' : random sampling following a Gaussian density (dense at the center, sparse at the edges)
'''
phase_encoding_locs = get_random_loc_1d(image.shape[0], accel=8, center_prop=0.1, pdf='gaussian')
print(f"phase_encoding_locs: \n {phase_encoding_locs}"  +
      f"\n\nmin(phase_encoding_locs): \n {min(phase_encoding_locs)}" +
      f"\n\nmax(phase_encoding_locs): \n {max(phase_encoding_locs)}")
#print(phase_encoding_locs, min(phase_encoding_locs), max(phase_encoding_locs))

# 4. Map the phase-encoding locations to pixel indices
'''
+0.5 : shift from [-0.5, 0.5) to [0, 1)
* image.shape[0] : scale to the pixel index range (0 ~ image.shape[0])
.astype(int) : convert to integer indices
'''
phase_encoding_locs = ((phase_encoding_locs + 0.5) * image.shape[0]).astype(int)

# 5. Build the k-space sampling mask
mask = np.zeros(image.shape, dtype=bool)                            # start from an all-zero mask
mask[phase_encoding_locs] = 1                                       # set only the selected rows to 1
phase_encoding_locs: 
 [ 0.          0.00390625 -0.00390625  0.0078125  -0.0078125   0.01171875
 -0.01171875  0.015625   -0.015625    0.01953125 -0.01953125  0.0234375
 -0.0234375   0.02734375 -0.02734375  0.03125    -0.03125     0.03515625
 -0.03515625  0.0390625  -0.0390625   0.04296875 -0.04296875  0.07421875
 -0.046875    0.09765625 -0.05078125  0.13671875 -0.05859375  0.14453125
 -0.08203125  0.1640625  -0.0859375   0.19140625 -0.09375     0.3125
 -0.10546875  0.32421875 -0.11328125  0.3671875  -0.140625   -0.15234375
 -0.1640625  -0.171875   -0.17578125 -0.18359375 -0.21484375 -0.2265625
 -0.2421875  -0.265625   -0.2734375  -0.296875   -0.30859375]

min(phase_encoding_locs): 
 -0.30859375

max(phase_encoding_locs): 
 0.3671875
In [242]:
print(phase_encoding_locs)
[128 129 127 130 126 131 125 132 124 133 123 134 122 135 121 136 120 137
 119 138 118 139 117 147 116 153 115 163 113 165 107 170 106 177 104 208
 101 211  99 222  92  89  86  84  83  81  73  70  66  60  58  52  49]
In [243]:
print( phase_encoding_locs.size - image.shape[0]//8 - int(image.shape[0]*0.1)   )  
-4
In [244]:
plt.subplot(1, 2, 1)
plt.imshow(np.abs(image), cmap='gray')
plt.title("MRI Data")
plt.subplot(1, 2, 2)
plt.imshow(mask, cmap='gray')
plt.title("K-space Sampling Mask")
plt.show()

$\leadsto$   The code above generates the sampling mask; the key part is

phase_encoding_locs = get_random_loc_1d(image.shape[0], accel=8, center_prop=0.1, pdf='gaussian')

where:

  • image.shape[0] represents the spatial size of the 1D mask vector.
  • accel=8 means that only 1 in every 8 lines is sampled, giving approximately image.shape[0]//8 = 32 sampled points.
  • center_prop=0.1 ensures that the central $10\%$ of the k-space is fully sampled, preserving low-frequency content, which gives int(image.shape[0]*0.1) = 25 lines.
  • pdf='gaussian' applies a Gaussian sampling distribution, where lower frequencies have a higher sampling probability than higher frequencies.

As a result, the final phase_encoding_locs.size = 53, since $4$ of the randomly drawn lines fall redundantly inside the fully sampled central $10\%$ (32 + 25 - 4 = 53).
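For reference, the same kind of variable-density line selection can be sketched without mrinufft. Everything here (the function name, the exact Gaussian width, the seed) is a hypothetical stand-in, not the library's implementation; unlike the library call above, this sketch draws the random lines strictly outside the center band, so no redundancy occurs:

```python
import numpy as np

def random_phase_encodes(n, accel=8, center_prop=0.1, seed=0):
    """Hypothetical stand-in for get_random_loc_1d, returning integer row indices:
    a fully sampled center band plus ~n//accel Gaussian-weighted random lines."""
    rng = np.random.default_rng(seed)
    n_center = int(n * center_prop)
    center = np.arange(n // 2 - n_center // 2, n // 2 + n_center // 2 + n_center % 2)
    rows = np.arange(n)
    # Gaussian weights: lines near the k-space center are more likely to be picked
    w = np.exp(-0.5 * ((rows - n / 2) / (n / 6)) ** 2)
    w[center] = 0.0                        # the center band is already fully sampled
    w /= w.sum()
    extra = rng.choice(rows, size=n // accel, replace=False, p=w)
    return np.union1d(center, extra)

locs = random_phase_encodes(256)
mask = np.zeros((256, 256), dtype=bool)
mask[locs] = True                          # sample whole rows (phase-encoding lines)
```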

Generate the kspace¶

From the 2D brain slice and the Cartesian acquisition mask, we retrospectively under-sample the k-space. We then reconstruct the zero-order solution as a baseline.

Get the locations of the kspace samples

In [245]:
# 1. Convert the mask into k-space sample locations

# Get the locations of the kspace samples
kspace_loc = convert_mask_to_locations(mask)                        # 2D binary mask -> a list of sampled k-space coordinates (x, y)

# Generate the subsampled kspace

# 2. Define the Fourier operator (FFT)
fourier_op = FFT(samples=kspace_loc, shape=image.shape)

# 3. Generate the undersampled k-space data
'''
.op(image) applies the forward Fourier transform (FFT) only at the sampled locations
kspace_data is a subsampled k-space representation of the image.
'''
kspace_data = fourier_op.op(image)
In [246]:
print(kspace_loc)
[[-0.30859375 -0.5       ]
 [-0.30859375 -0.49609375]
 [-0.30859375 -0.4921875 ]
 ...
 [ 0.3671875   0.48828125]
 [ 0.3671875   0.4921875 ]
 [ 0.3671875   0.49609375]]
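The printed values above suggest the convention: a sampled index $k$ on an axis of length $n$ maps to the normalized frequency $k/n - 0.5 \in [-0.5, 0.5)$ (e.g., row 49 of 256 gives $49/256 - 0.5 = -0.30859375$, the first coordinate printed). A plain-numpy sketch of this assumed behavior (not the actual pysap-mri implementation):

```python
import numpy as np

def mask_to_locations(mask):
    """Assumed behavior of convert_mask_to_locations: map the indices of the
    sampled k-space points to normalized frequencies in [-0.5, 0.5)."""
    idx = np.argwhere(mask)                      # (n_samples, ndim) integer indices
    return idx / np.array(mask.shape) - 0.5
```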

Zero order solution

In [247]:
# Perform the inverse Fourier transform
zero_soln = fourier_op.adj_op(kspace_data)
base_ssim = ssim(zero_soln, image)
plt.imshow(np.abs(zero_soln), cmap='gray')
plt.title('Zero Order Solution : SSIM = ' + str(np.around(base_ssim, 3)))
plt.show()

$\leadsto$   Here we apply the Fourier operator and its adjoint to compute the zero-order solution:

  • fourier_op = FFT(samples=kspace_loc, shape=image.shape)

    where we define the Fourier transform (FFT) operator only for the sampled locations ( kspace_loc ),

    then we apply this operator to the image: kspace_data = fourier_op.op(image) to obtain the under-sampled k-space data.

  • To compute the zero-order solution, we simply apply the adjoint of the operator: zero_soln = fourier_op.adj_op(kspace_data), which serves as a basic (zero-filled) reconstruction.
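The op/adj_op pair can be mimicked with plain numpy. The sketch below (a masked, centered, orthonormal 2D FFT, assumed to mirror pysap-mri's FFT operator up to conventions) makes explicit that the adjoint is exactly the zero-filled inverse FFT used for the zero-order solution:

```python
import numpy as np

class CartesianFFT:
    """Sketch of a masked Fourier operator: op() keeps only the sampled
    k-space entries, adj_op() zero-fills the missing ones and inverts."""
    def __init__(self, mask):
        self.mask = mask
    def op(self, x):
        # centered, orthonormal 2D FFT, then keep only the sampled locations
        return np.fft.fftshift(np.fft.fft2(x, norm="ortho"))[self.mask]
    def adj_op(self, y):
        # zero-fill the unsampled locations, then invert (zero-filled recon)
        k = np.zeros(self.mask.shape, dtype=complex)
        k[self.mask] = y
        return np.fft.ifft2(np.fft.ifftshift(k), norm="ortho")
```

Because the orthonormal FFT is unitary and masking is a selection, this pair satisfies the adjoint identity $\langle \mathrm{op}(x), y\rangle = \langle x, \mathrm{adj\_op}(y)\rangle$, which is what the iterative solvers below rely on.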

Synthesis formulation: FISTA vs POGM optimization¶

We now want to refine the zero-order solution using a compressed sensing reconstruction. Here we adopt the synthesis formulation solved with the FISTA algorithm. The cost function is the sum of the proximity (regularization) cost and the gradient (data-fidelity) cost.
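The FISTA iteration itself is short. A generic numpy sketch on the problem $\min_z \frac{1}{2}\|Az - b\|_2^2 + \lambda\|z\|_1$ (with $A$ standing in for $\Omega F \Psi^*$), assuming a fixed step size $1/L$ with $L = \|A\|^2$; this is not modopt's implementation, just the textbook update:

```python
import numpy as np

def fista(A, b, lam, n_iter=200):
    """FISTA for 0.5*||A z - b||^2 + lam*||z||_1 (synthesis formulation)."""
    L = np.linalg.norm(A, 2) ** 2                 # Lipschitz constant of the gradient
    z = x = np.zeros(A.shape[1]); t = 1.0
    for _ in range(n_iter):
        grad = A.T @ (A @ z - b)                  # gradient at the extrapolated point
        u = z - grad / L
        x_new = np.sign(u) * np.maximum(np.abs(u) - lam / L, 0.0)   # soft threshold
        t_new = (1 + np.sqrt(1 + 4 * t * t)) / 2
        z = x_new + (t - 1) / t_new * (x_new - x)                   # Nesterov momentum
        x, t = x_new, t_new
    return x
```

POGM follows the same gradient + soft-threshold structure with a more aggressive momentum sequence, which is why the two algorithms are interchangeable in the `reconstruct` calls below.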

FISTA¶

In [248]:
# Setup the operators
linear_op = WaveletN(wavelet_name="sym8", nb_scales=4)
regularizer_op = SparseThreshold(Identity(), 0.1, thresh_type="soft")           
# Setup Reconstructor
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,                                                  # the **Fourier transform** (FFT) operator only for the sampled locations ( `kspace_loc` )
    linear_op=linear_op,                                                    # Symmlet-8 wavelet transform with 4 scales for multi-scale sparse representation.
    regularizer_op=regularizer_op,                                          # Identity() < - > identity linear transformation, \lambda = 0.1, Soft Thresholding < - > L1 norm
    gradient_formulation='synthesis',                                       # using Synthesis Formulation
    verbose=1,
)
Lipschitz constant is 1.1000000083858883
WARNING: Making input data immutable.
In [ ]:
image_rec, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_data,                                                # kspace_data = fourier_op.op(image)
    optimization_alg='fista',                                               # fista
    num_iterations=200,
)
recon_ssim = ssim(image_rec, image)
plt.imshow(np.abs(image_rec), cmap='gray')
plt.title('Iterative Reconstruction FISTA: SSIM = ' + str(np.around(recon_ssim, 3)))
plt.show()


POGM optimization¶

In [ ]:
image_rec2, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_data,                                                # kspace_data = fourier_op.op(image)
    optimization_alg='pogm',                                                # pogm
    num_iterations=200,
)
recon2_ssim = ssim(image_rec2, image)
plt.imshow(np.abs(image_rec2), cmap='gray')
plt.title('Iterative Reconstruction POGM: SSIM = ' + str(np.around(recon2_ssim, 3)))
plt.show()


Q1  

  • Synthesis CS reconstruction: Compare the two optimization algorithms, namely FISTA and POGM.

    Make your own comments on

    • the final image quality,

    • convergence speed (play with the number of iterations),

    • computation cost, etc.

    • Additionally, you may vary the number of scales in the wavelet transform (i.e. WaveletN(wavelet_name="sym8", nb_scales=4)).

In [251]:
def reconstruction (optimization_alg_, num_iterations_, nb_scales_):
    if optimization_alg_ == "fista":
        # Setup the operators
        linear_op = WaveletN(wavelet_name="sym8", nb_scales=nb_scales_)
        regularizer_op = SparseThreshold(Identity(), 0.1, thresh_type="soft")           
        # Setup Reconstructor
        reconstructor = SingleChannelReconstructor(
            fourier_op=fourier_op,                                                  # the **Fourier transform** (FFT) operator only for the sampled locations ( `kspace_loc` )
            linear_op=linear_op,                                                    # Symmlet-8 wavelet transform with 4 scales for multi-scale sparse representation.
            regularizer_op=regularizer_op,                                          # Identity() < - > identity linear transformation, \lambda = 0.1, Soft Thresholding < - > L1 norm
            gradient_formulation='synthesis',                                       # using Synthesis Formulation
            verbose=1,
        )

        image_rec, costs, metrics = reconstructor.reconstruct(
            kspace_data=kspace_data,                                                # kspace_data = fourier_op.op(image)
            optimization_alg='fista',                                               # fista
            num_iterations=num_iterations_,
        )

    elif optimization_alg_ == "pogm":
        # Setup the operators
        linear_op = WaveletN(wavelet_name="sym8", nb_scales=nb_scales_)
        regularizer_op = SparseThreshold(Identity(), 0.1, thresh_type="soft")           
        # Setup Reconstructor
        reconstructor = SingleChannelReconstructor(
            fourier_op=fourier_op,                                                  # the **Fourier transform** (FFT) operator only for the sampled locations ( `kspace_loc` )
            linear_op=linear_op,                                                    # Symmlet-8 wavelet transform with 4 scales for multi-scale sparse representation.
            regularizer_op=regularizer_op,                                          # Identity() < - > identity linear transformation, \lambda = 0.1, Soft Thresholding < - > L1 norm
            gradient_formulation='synthesis',                                       # using Synthesis Formulation
            verbose=1,
        )
        image_rec, costs, metrics = reconstructor.reconstruct(
            kspace_data=kspace_data,                                                # kspace_data = fourier_op.op(image)
            optimization_alg='pogm',                                                # pogm
            num_iterations=num_iterations_,
        )

    elif optimization_alg_ == "condatvu":
        linear_op = WaveletUD2(
            wavelet_id=24,
            nb_scale=nb_scales_,
        )
        regularizer_op = SparseThreshold(Identity(), 0.1, thresh_type="soft")

        reconstructor = SingleChannelReconstructor(
            fourier_op=fourier_op,                                                      # the **Fourier transform** (FFT) operator only for the sampled locations ( `kspace_loc` )
            linear_op=linear_op,                                                        # undecimated bi-orthogonal wavelet for the sparsity constraint, multi-scale decomposition
            regularizer_op=regularizer_op,                                              # Identity() < - > identity linear transformation, \lambda = 0.1, Soft Thresholding < - > L1 norm             
            gradient_formulation='analysis',                                            # using Analysis Formulation
            verbose=1,
        )

        image_rec, costs, metrics = reconstructor.reconstruct(
            kspace_data=kspace_data,                                                    # kspace_data = fourier_op.op(image)
            optimization_alg='condatvu',                                                # Condat-Vu
            num_iterations=num_iterations_,
        )

    recon_ssim = ssim(image_rec, image)

    return image_rec, recon_ssim
In [ ]:
vector_optimization_alg = ["fista", "pogm", "condatvu"]
vector_num_iterations = [20, 200, 500, 1000]
vector_nb_scales = [2, 4, 8, 16]


for p in ["fista", "pogm"]:
    optimization_alg_ = p

    fig, axs = plt.subplots(len(vector_num_iterations), len(vector_nb_scales), figsize=(12, 12))
    for i in range(len(vector_num_iterations)):
        num_iterations_ = vector_num_iterations[i]

        for j in range(len(vector_nb_scales)):
            nb_scales_ = vector_nb_scales[j]

            image_rec, recon_ssim = reconstruction (optimization_alg_, num_iterations_, nb_scales_)

            axs[i, j].imshow(np.abs(image_rec), cmap='gray')
            axs[i, j].set_xlabel(f"nb_scales = {nb_scales_} SSIM = {recon_ssim:.3f}")
            axs[i, j].xaxis.set_label_position('top')  # move the xlabel to the top
            axs[i, j].set_ylabel(f"nb_iter = {num_iterations_}")

    # Add a main title (suptitle) to the figure 
    fig.suptitle(f"Reconstruction Results for Optimization Alg: {p}", fontsize=16)

    # Display the figure with adjusted layout
    plt.tight_layout(rect=[0, 0, 1, 0.96])  # Adjust layout to fit suptitle
    plt.show()



$\leadsto$   Synthesis CS reconstruction: Compare the two optimization algorithms, namely FISTA and POGM.

$\circ$ When we vary the number of scales nb_scales in the wavelet transform, the final image quality ( SSIM ) remains unchanged up to three decimal places for both algorithms.

$\circ$ As num_iterations increases, the quality of the reconstructed image first improves, up to num_iterations ≃ 1000, after which it starts to degrade, again for both algorithms.

$\circ$ Regarding the final image quality, under the same conditions ( nb_scales, num_iterations ), the difference between the two algorithms is generally insignificant.

In [ ]:
vector_optimization_alg = ["fista", "pogm", "condatvu"]
vector_num_iterations = np.concatenate([
    np.linspace(20, 1000, 50), 
    np.linspace(1001, 10000, 10)
]) # [20, 200, 500, 1000]
nb_scales_ = 4

vector_SSIM_FISTA = []
vector_SSIM_POGM = []

optimization_alg_ = "fista"
for i in range(len(vector_num_iterations)):
    num_iterations_ = int(vector_num_iterations[i])
    image_rec, recon_ssim = reconstruction (optimization_alg_, num_iterations_, nb_scales_)
    vector_SSIM_FISTA.append(recon_ssim)

optimization_alg_ = "pogm"
for j in range(len(vector_num_iterations)):
    num_iterations_ = int(vector_num_iterations[j])
    image_rec, recon_ssim = reconstruction (optimization_alg_, num_iterations_, nb_scales_)
    vector_SSIM_POGM.append(recon_ssim)

plt.figure(figsize=(10, 6))

# FISTA plot
plt.plot(vector_num_iterations, vector_SSIM_FISTA, linestyle='-', color='red', 
        markersize=8, markerfacecolor='red', label='FISTA')

# POGM plot
plt.plot(vector_num_iterations, vector_SSIM_POGM, linestyle='-', color='blue', 
        markersize=8, markerfacecolor='blue', label='POGM')

# Labels and title
plt.xlabel('num_iterations')
plt.ylabel('recon_ssim')
plt.title('SSIM Score vs Number of Iterations')

# Grid and legend
# plt.grid(True)
plt.legend()

# Show the plot
plt.show()


The results show that the maximum SSIM score for the FISTA algorithm occurs at num_iterations = 940. Below, we display the corresponding reconstructed image:

  • num_iterations = 940 (corresponding to the maximum SSIM Score)

In this scenario, we fixed nb_scales = 4 .

In [ ]:
# Find the index and value of the maximum vector_SSIM_FISTA.
max_index_FISTA = np.argmax(vector_SSIM_FISTA)
SSIM_corresponding_max_index_FISTA = vector_SSIM_FISTA[max_index_FISTA]
num_iterations_corresponding_max_index_FISTA = vector_num_iterations[max_index_FISTA]
image_corresponding_max_index_FISTA, recon_ssim_FISTA= reconstruction("fista", int(num_iterations_corresponding_max_index_FISTA), 4)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))

axes[0].imshow(np.abs(image_corresponding_max_index_FISTA), cmap='gray')
axes[0].set_title("FISTA" + f"  num_iter : {int(num_iterations_corresponding_max_index_FISTA)}" +  f" : SSIM = {recon_ssim_FISTA:.3f}")
axes[0].axis('off')

axes[1].imshow(np.abs(zero_soln), cmap='gray')
axes[1].set_title('Zero Order Solution : SSIM = ' + str(np.around(base_ssim, 3)))
axes[1].axis('off')

plt.tight_layout()
plt.show()


The results show that the maximum SSIM score for the POGM algorithm occurs at num_iterations = 780. Below, we display the corresponding reconstructed image:

  • num_iterations = 780 (corresponding to the maximum SSIM Score)

In this scenario, we fixed nb_scales = 4 .

In [ ]:
# Find the index and value of the maximum vector_SSIM_POGM.
max_index_POGM = np.argmax(vector_SSIM_POGM)
SSIM_corresponding_max_index_POGM = vector_SSIM_POGM[max_index_POGM]
num_iterations_corresponding_max_index_POGM = vector_num_iterations[max_index_POGM]
image_corresponding_max_index_POGM, recon_ssim_POGM= reconstruction("pogm", int(num_iterations_corresponding_max_index_POGM), 4)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))

axes[0].imshow(np.abs(image_corresponding_max_index_POGM), cmap='gray')
axes[0].set_title("POGM" + f"  num_iter : {int(num_iterations_corresponding_max_index_POGM)}" +  f" : SSIM = {recon_ssim_POGM:.3f}")
axes[0].axis('off')

axes[1].imshow(np.abs(zero_soln), cmap='gray')
axes[1].set_title('Zero Order Solution : SSIM = ' + str(np.around(base_ssim, 3)))
axes[1].axis('off')

plt.tight_layout()
plt.show()


Analysis formulation: Condat-Vu reconstruction¶

In [256]:
linear_op = WaveletUD2(
    wavelet_id=24,
    nb_scale=4,
)

regularizer_op = SparseThreshold(Identity(), 0.1, thresh_type="soft")
In [257]:
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,                                                      # the **Fourier transform** (FFT) operator only for the sampled locations ( `kspace_loc` )
    linear_op=linear_op,                                                        # undecimated bi-orthogonal wavelet for the sparsity constraint, 4-level decomposition
    regularizer_op=regularizer_op,                                              # Identity() < - > identity linear transformation, \lambda = 0.1, Soft Thresholding < - > L1 norm             
    gradient_formulation='analysis',                                            # using Analysis Formulation
    verbose=1,
)
Lipschitz constant is 1.0999999836083485
WARNING: Making input data immutable.
In [ ]:
image_rec3, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_data,                                                    # kspace_data = fourier_op.op(image)
    optimization_alg='condatvu',                                                # Condat-Vu
    num_iterations=200,
)
recon3_ssim = ssim(image_rec3, image)
plt.imshow(np.abs(image_rec3), cmap='gray')
plt.title('Iterative Reconstruction Condat-Vu: SSIM = ' + str(np.around(recon3_ssim, 3)))
plt.show()

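The Condat-Vu update alternates a gradient step on the data term with a dual ascent step on the analysis coefficients. A toy numpy sketch on the 1D analysis problem $\min_x \frac{1}{2}\|x - b\|_2^2 + \lambda\|Dx\|_1$, where the finite-difference operator $D$ is a stand-in for the wavelet analysis operator (this is the textbook update rule, not modopt's implementation):

```python
import numpy as np

def condat_vu(b, lam, n_iter=500):
    """Condat-Vu primal-dual iteration for 0.5*||x - b||^2 + lam*||D x||_1."""
    n = b.size
    D = np.diff(np.eye(n), axis=0)              # (n-1, n) finite-difference operator
    tau, sigma = 0.5, 0.25                      # 1/tau - sigma*||D||^2 >= beta/2 (||D||^2 <= 4, beta = 1)
    x, y = b.astype(float), np.zeros(n - 1)
    for _ in range(n_iter):
        x_new = x - tau * ((x - b) + D.T @ y)   # gradient of the data term + dual feedback
        # prox of the conjugate of lam*||.||_1 is projection onto the l-inf ball
        y = np.clip(y + sigma * D @ (2 * x_new - x), -lam, lam)
        x = x_new
    return x
```

The key point is that $\|Dx\|_1$ is never prox-ed directly (its proximal operator has no closed form); the dual variable y absorbs it through a simple clipping, which is what makes the analysis formulation tractable.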

Q2  

  • Analysis vs Synthesis CS reconstruction: Use the Condat-Vu algorithm to minimize the analysis formulation.

    • First, use the same sparsifying transform (i.e., linear_op set to a Daubechies wavelet, db4) and compare the results between the two formulations.

    • Second, use an undecimated wavelet transform (linear_op = WaveletUD2) and rerun the analysis formulation. Comment on the results.

In [ ]:
# Setup the operators
linear_op = WaveletN(wavelet_name="db4", nb_scales=4)
regularizer_op = SparseThreshold(Identity(), 0.1, thresh_type="soft")           
# Setup Reconstructor
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,                                                  # the **Fourier transform** (FFT) operator only for the sampled locations ( `kspace_loc` )
    linear_op=linear_op,                                                    # Daubechies-4 (db4) wavelet transform with 4 scales for multi-scale sparse representation.
    regularizer_op=regularizer_op,                                          # Identity() < - > identity linear transformation, \lambda = 0.1, Soft Thresholding < - > L1 norm
    gradient_formulation='synthesis',                                       # using Synthesis Formulation
    verbose=1,
)

image_rec4, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_data,                                                # kspace_data = fourier_op.op(image)
    optimization_alg='fista',                                               # fista
    num_iterations=200,
)
recon_ssim4 = ssim(image_rec4, image)
plt.imshow(np.abs(image_rec4), cmap='gray')
plt.title('Iterative Reconstruction FISTA (synthesis) WaveletN : SSIM = ' + str(np.around(recon_ssim4, 3)))
plt.show()


In [ ]:
# Setup the operators
linear_op = WaveletN(wavelet_name="db4", nb_scales=4)
regularizer_op = SparseThreshold(Identity(), 0.1, thresh_type="soft")           
# Setup Reconstructor
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,                                                  # the **Fourier transform** (FFT) operator only for the sampled locations ( `kspace_loc` )
    linear_op=linear_op,                                                    # Daubechies-4 (db4) wavelet transform with 4 scales for multi-scale sparse representation.
    regularizer_op=regularizer_op,                                          # Identity() < - > identity linear transformation, \lambda = 0.1, Soft Thresholding < - > L1 norm
    gradient_formulation='analysis',                                        # using Analysis Formulation
    verbose=1,
)

image_rec5, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_data,                                                # kspace_data = fourier_op.op(image)
    optimization_alg='condatvu',                                            # condatvu
    num_iterations=200,
)
recon_ssim5 = ssim(image_rec5, image)
plt.imshow(np.abs(image_rec5), cmap='gray')
plt.title('Reconstruction Condat-Vu (analysis) WaveletN : SSIM = ' + str(np.around(recon_ssim5, 3)))
plt.show()


In [ ]:
# Setup the operators
linear_op = WaveletUD2(wavelet_id=24, nb_scale=4)
regularizer_op = SparseThreshold(Identity(), 0.1, thresh_type="soft")           
# Setup Reconstructor
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,                                                  # the **Fourier transform** (FFT) operator only for the sampled locations ( `kspace_loc` )
    linear_op=linear_op,                                                    # undecimated bi-orthogonal wavelet for the sparsity constraint, 4-level decomposition
    regularizer_op=regularizer_op,                                          # Identity() < - > identity linear transformation, \lambda = 0.1, Soft Thresholding < - > L1 norm
    gradient_formulation='analysis',                                        # using Analysis Formulation
    verbose=1,
)

image_rec6, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_data,                                                # kspace_data = fourier_op.op(image)
    optimization_alg='condatvu',                                            # condatvu
    num_iterations=200,
)
recon_ssim6 = ssim(image_rec6, image)
plt.imshow(np.abs(image_rec6), cmap='gray')
plt.title('Reconstruction Condat-Vu (analysis) WaveletUD2 : SSIM = ' + str(np.around(recon_ssim6, 3)))
plt.show()


In [262]:
def reconstruction_v1 (Wavelet_, gradient_formulation_, num_iterations_):
    # Setup the operators
    if Wavelet_ == "WaveletN":
        linear_op = WaveletN(wavelet_name="db4", nb_scales=4)                       # Daubechies-4 (db4) wavelet transform, 4 scales
    elif Wavelet_ == "WaveletUD2":
        linear_op = WaveletUD2(wavelet_id=24, nb_scale=4)                           # undecimated bi-orthogonal wavelet, 4-level decomposition
        
    regularizer_op = SparseThreshold(Identity(), 0.1, thresh_type="soft")  

    if gradient_formulation_ == "synthesis":
        reconstructor = SingleChannelReconstructor(
            fourier_op=fourier_op,                                                  # the **Fourier transform** (FFT) operator only for the sampled locations ( `kspace_loc` )
            linear_op=linear_op,                                                    # 
            regularizer_op=regularizer_op,                                          # Identity() < - > identity linear transformation, \lambda = 0.1, Soft Thresholding < - > L1 norm
            gradient_formulation='synthesis',                                       # using Synthesis Formulation
            verbose=1,
        )

        image_rec, costs, metrics = reconstructor.reconstruct(
            kspace_data=kspace_data,                                                # kspace_data = fourier_op.op(image)
            optimization_alg='fista',                                               # fista
            num_iterations=num_iterations_,
        )

    elif gradient_formulation_ == "analysis":
        reconstructor = SingleChannelReconstructor(
            fourier_op=fourier_op,                                                  # the **Fourier transform** (FFT) operator only for the sampled locations ( `kspace_loc` )
            linear_op=linear_op,                                                    # 
            regularizer_op=regularizer_op,                                          # Identity() < - > identity linear transformation, \lambda = 0.1, Soft Thresholding < - > L1 norm
            gradient_formulation='analysis',                                        # using Analysis Formulation
            verbose=1,
        )

        image_rec, costs, metrics = reconstructor.reconstruct(
            kspace_data=kspace_data,                                                # kspace_data = fourier_op.op(image)
            optimization_alg='condatvu',                                            # condatvu
            num_iterations=num_iterations_,
        )

    recon_ssim = ssim(image_rec, image)

    return image_rec, recon_ssim
        
In [ ]:
vector_Wavelet=["WaveletN", "WaveletN", "WaveletUD2"]
vector_gradient_formulation=["synthesis", "analysis", "analysis"]
vector_num_iterations = [20, 100, 400, 500, 600]


fig, axs = plt.subplots(len(vector_num_iterations), 3, figsize=(12, 12))

for i in range(len(vector_num_iterations)) :
    for j in range(3):
        num_iterations_ = vector_num_iterations[i]

        Wavelet_ = vector_Wavelet[j]
        gradient_formulation_ = vector_gradient_formulation[j]

        image_rec, recon_ssim = reconstruction_v1(Wavelet_, gradient_formulation_, num_iterations_)

        # display the reconstructed image
        axs[i, j].imshow(np.abs(image_rec), cmap='gray')
        
        axs[i, j].set_xlabel(f"{Wavelet_}: SSIM = {recon_ssim:.3f}")
        axs[i, j].xaxis.set_label_position('top')  # move the xlabel to the top
        axs[i, j].set_ylabel(f"num_iter = {num_iterations_}")
    

axs[0, 0].set_title("'synthesis' + 'fista'")
axs[0, 1].set_title("'analysis' + 'condatvu'")
axs[0, 2].set_title("'analysis' + 'condatvu'")

# Add a main title (suptitle) to the figure
fig.suptitle("Iterative Reconstruction Results")

# Display the figure with adjusted layout
plt.tight_layout(rect=[0, 0, 1, 0.96])  # Adjust layout to fit suptitle
plt.show()


$\leadsto$   Analysis vs Synthesis CS reconstruction: Use the Condat-Vu algorithm to minimize the analysis formulation.

$\circ$ First, using the same sparsifying transform ( linear_op = WaveletN(wavelet_name="db4", nb_scales=4) ), the 'synthesis' formulation yields better reconstruction quality than the 'analysis' formulation.

$\circ$ Second, using an undecimated wavelet transform ( linear_op = WaveletUD2(wavelet_id=24, nb_scale=4) ) and re-running the 'analysis' formulation, the reconstruction quality improves but remains slightly below the first case (i.e., WaveletN with 'synthesis' ) for num_iterations < 600. However, for num_iterations > 600, the WaveletUD2 'analysis' reconstruction becomes increasingly better than the WaveletN 'synthesis' one.

In [ ]:
vector_Wavelet=["WaveletN", "WaveletN", "WaveletUD2"]
vector_gradient_formulation=["synthesis", "analysis", "analysis"]
vector_num_iterations = np.concatenate([
    np.linspace(20, 1000, 50), 
    np.linspace(1001, 10000, 10)
]) # [20, 200, 500, 1000]

matrix_recon_ssim = np.zeros((len(vector_num_iterations), 3))

for i in range(len(vector_num_iterations)) :
    for j in range(3):
        num_iterations_ = vector_num_iterations[i]

        Wavelet_ = vector_Wavelet[j]
        gradient_formulation_ = vector_gradient_formulation[j]

        image_rec, recon_ssim = reconstruction_v1(Wavelet_, gradient_formulation_, int(num_iterations_))

        matrix_recon_ssim[i][j] = recon_ssim

plt.figure(figsize=(10, 6))

# Note: markersize/markerfacecolor have no effect unless marker= is set.

# WaveletN Synthesis FISTA plot
plt.plot(vector_num_iterations, matrix_recon_ssim[:, 0], linestyle='-', marker='o', color='red',
         markersize=8, markerfacecolor='red', label='WaveletN Synthesis FISTA')

# WaveletN Analysis Condat-Vu plot
plt.plot(vector_num_iterations, matrix_recon_ssim[:, 1], linestyle='-', marker='o', color='blue',
         markersize=8, markerfacecolor='blue', label='WaveletN Analysis Condat-Vu')

# WaveletUD2 Analysis Condat-Vu plot
plt.plot(vector_num_iterations, matrix_recon_ssim[:, 2], linestyle='-', marker='o', color='green',
         markersize=8, markerfacecolor='green', label='WaveletUD2 Analysis Condat-Vu')


# Labels and title
plt.xlabel('num_iterations')
plt.ylabel('recon_ssim')
plt.title('SSIM Score vs Number of Iterations')

# Grid and legend
# plt.grid(True)
plt.legend()

# Show the plot
plt.show()
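For reference, the SSIM scores plotted above come from reconstruction_v1. A simplified, single-window variant of SSIM can be sketched in plain NumPy (this is a sanity-check sketch, not the windowed SSIM implementation used by the reconstruction code; images are assumed normalized to [0, 1]):

```python
import numpy as np

def global_ssim(x, y, data_range=1.0):
    """Single-window SSIM computed over the whole image (no sliding window)."""
    c1 = (0.01 * data_range) ** 2  # stabilizing constants from the SSIM paper
    c2 = (0.03 * data_range) ** 2
    mx, my = x.mean(), y.mean()
    vx, vy = x.var(), y.var()
    cov = ((x - mx) * (y - my)).mean()
    return ((2 * mx * my + c1) * (2 * cov + c2)) / ((mx**2 + my**2 + c1) * (vx + vy + c2))

# Identical images score ~1; any distortion lowers the score.
img = np.random.rand(64, 64)
print(global_ssim(img, img))
```

Because it averages statistics over the whole image, this global variant is coarser than the windowed SSIM, but it behaves the same way qualitatively and is enough to check that a reconstruction pipeline's scores are plausible.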

(figure: SSIM score vs number of iterations for the three reconstruction setups)

Configuration¶

In [265]:
from pysap.extensions import sparse2d
print(sparse2d)
<module 'pysap.extensions.sparse2d' from '/home/home/miniconda3/envs/tensor/lib/python3.11/site-packages/pysap/extensions/sparse2d.py'>
In [266]:
import numpy as np
import warnings
import pysap
import pysap.base.utils as utils
from pysap.base.transform import MetaRegister
from pysap.base import image
try:
    import pysparse
except ImportError:  
    warnings.warn("Sparse2d python bindings not found, use binaries.")
    pysparse = None
In [267]:
import pysparse
print(pysparse)
<module 'pysparse' from '/home/home/miniconda3/envs/tensor/lib/python3.11/site-packages/pysparse.cpython-311-x86_64-linux-gnu.so'>

WSL: Ubuntu-24.04
Miniconda3-py312_24.11.1-0-Linux-x86_64.sh
VSCodeUserSetup-x64-1.97.2.exe

0. install pysap¶

python

pip install git+https://github.com/CEA-COSMIC/pysap.git

1. install Sparse2D¶

python

from pysap.extensions import sparse2d
print(sparse2d) 

output

/home/home/miniconda3/envs/tensor/lib/python3.11/site-packages/pysap/extensions/transform.py:41: UserWarning: Sparse2D Python bindings not found. Any call to a Sparse2D transform or a plug-in method that uses a Sparse2D transform will result in an error. warnings.warn(
/home/home/miniconda3/envs/tensor/lib/python3.11/site-packages/pysap/extensions/sparse2d.py:25: UserWarning: Sparse2d python bindings not found, use binaries. warnings.warn("Sparse2d python bindings not found, use binaries.")

bash

git clone https://github.com/CosmoStat/Sparse2D.git
cd Sparse2D
mkdir build
cd build
cmake .. -DCMAKE_INSTALL_PREFIX=../install -DONLY_SPARSE=ON -DUSE_FFTW=OFF -DBUILD_CFITSIO=ON
make -j8
make install

Find the directory containing pysparse.cpython-311-x86_64-linux-gnu.so, then add it to LD_LIBRARY_PATH:

bash

echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/home/projects/tensor/Sparse2D/install/lib' >> ~/.bashrc
source ~/.bashrc

python

from pysap.extensions import sparse2d
print(sparse2d) 

output

<module 'pysap.extensions.sparse2d' from '/home/home/miniconda3/envs/tensor/lib/python3.11/site-packages/pysap/extensions/sparse2d.py'>

2. check for dependencies of shared libraries¶

bash

ldd /home/home/projects/tensor/Sparse2D/install/lib/pysparse.cpython-311-x86_64-linux-gnu.so

output

linux-vdso.so.1 (0x00007ffd579f7000)
libcfitsio.so.9 => not found
libpython3.11.so.1.0 => not found
libgomp.so.1 => /home/home/miniconda3/envs/tensor/lib/libgomp.so.1 (0x00007f250db52000)
libstdc++.so.6 => /home/home/miniconda3/envs/tensor/lib/libstdc++.so.6 (0x00007f250d93e000)
libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007f250d84c000)
libgcc_s.so.1 => /home/home/miniconda3/envs/tensor/lib/libgcc_s.so.1 (0x00007f250d832000)
libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f250d620000)
libz.so.1 => /home/home/miniconda3/envs/tensor/lib/libz.so.1 (0x00007f250d602000)
libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f250d5fd000)
libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007f250d5f8000)
libutil.so.1 => /lib/x86_64-linux-gnu/libutil.so.1 (0x00007f250d5f1000)
/lib64/ld-linux-x86-64.so.2 (0x00007f250e637000)

Find the directory containing libcfitsio.so.9, then add it to LD_LIBRARY_PATH:

bash

echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/home/projects/tensor/Sparse2D/build/lib' >> ~/.bashrc
source ~/.bashrc

Find the directory containing libpython3.11.so.1.0, then add it to LD_LIBRARY_PATH:

bash

echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/home/miniconda3/envs/tensor/lib' >> ~/.bashrc
source ~/.bashrc

If libcfitsio.so.9 or libpython3.11.so.1.0 cannot be found anywhere on disk, first install the system development package with sudo apt-get install libcfitsio-dev

bash

echo $LD_LIBRARY_PATH

output

/usr/local/cuda-12.6/lib64::/home/home/projects/tensor/Sparse2D/build/lib:/home/home/projects/tensor/Sparse2D/install/lib:/home/home/miniconda3/envs/tensor/lib
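Note the double colon after the CUDA entry: an empty element in LD_LIBRARY_PATH is interpreted by the dynamic loader as the current working directory, which is usually unintended. A quick sketch to list which entries actually exist (the helper name check_ld_library_path is ours, for illustration only):

```python
import os

def check_ld_library_path(value=None):
    """Report each LD_LIBRARY_PATH entry and whether the directory exists."""
    if value is None:
        value = os.environ.get("LD_LIBRARY_PATH", "")
    report = {}
    for entry in value.split(":"):
        if not entry:  # empty entries come from leading/trailing/double colons
            continue
        report[entry] = os.path.isdir(entry)
    return report

# Example with the value shown in the output above:
example = ("/usr/local/cuda-12.6/lib64::"
           "/home/home/projects/tensor/Sparse2D/build/lib:"
           "/home/home/projects/tensor/Sparse2D/install/lib:"
           "/home/home/miniconda3/envs/tensor/lib")
for path, exists in check_ld_library_path(example).items():
    print(path, "OK" if exists else "MISSING")
```

Entries reported MISSING are either typos or directories that have not been created yet; both silently make ldd report "not found".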

bash

ldd /home/home/projects/tensor/Sparse2D/install/lib/pysparse.cpython-311-x86_64-linux-gnu.so

output

linux-vdso.so.1 (0x00007ffd72dcf000)
libcfitsio.so.9 => /home/home/projects/tensor/Sparse2D/build/lib/libcfitsio.so.9 (0x00007f0ca309b000)
libpython3.11.so.1.0 => /home/home/miniconda3/envs/tensor/lib/libpython3.11.so.1.0 (0x00007f0ca2ace000)
libgomp.so.1 => /home/home/miniconda3/envs/tensor/lib/libgomp.so.1 (0x00007f0ca2a8a000)
libstdc++.so.6 => /home/home/miniconda3/envs/tensor/lib/libstdc++.so.6 (0x00007f0ca2876000)
libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007f0ca2784000)
libgcc_s.so.1 => /home/home/miniconda3/envs/tensor/lib/libgcc_s.so.1 (0x00007f0ca276a000)
libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f0ca2558000)
libz.so.1 => /home/home/miniconda3/envs/tensor/lib/libz.so.1 (0x00007f0ca253a000)
libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f0ca2535000)
libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007f0ca2530000)
libutil.so.1 => /lib/x86_64-linux-gnu/libutil.so.1 (0x00007f0ca2529000)
/lib64/ld-linux-x86-64.so.2 (0x00007f0ca356f000)

3. test¶

python

from pysap.extensions import sparse2d
print(sparse2d)

output

<module 'pysap.extensions.sparse2d' from '/home/home/miniconda3/envs/tensor/lib/python3.11/site-packages/pysap/extensions/sparse2d.py'>

python

import pysparse
print(pysparse)

output

<module 'pysparse' from '/home/home/miniconda3/envs/tensor/lib/python3.11/site-packages/pysparse.cpython-311-x86_64-linux-gnu.so'>